{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TVM Pass Instrument\n",
    "\n",
    "参考：[如何使用 TVM Pass Instrument](https://xinetzone.github.io/tvm/docs/how_to/extend_tvm/use_pass_instrument.html)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "import tvm.relay as relay\n",
    "from tvm.relay.testing import resnet\n",
    "from tvm.contrib.download import download_testdata\n",
    "from tvm.relay.build_module import bind_params_by_name\n",
    "from tvm.ir.instrument import (\n",
    "    PassTimingInstrument,\n",
    "    pass_instrument,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 1\n",
    "num_of_image_class = 1000\n",
    "image_shape = (3, 224, 224)\n",
    "output_shape = (batch_size, num_of_image_class)\n",
    "relay_mod, relay_params = resnet.get_workload(num_layers=18, batch_size=1, image_shape=image_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Printing results of timing profile...\n",
      "InferType: 11228us [11228us] (53.85%; 53.85%)\n",
      "FoldScaleAxis: 9621us [7us] (46.15%; 46.15%)\n",
      "\tFoldConstant: 9614us [2007us] (46.11%; 99.92%)\n",
      "\t\tInferType: 7607us [7607us] (36.49%; 79.13%)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "timing_inst = PassTimingInstrument()\n",
    "with tvm.transform.PassContext(instruments=[timing_inst]):\n",
    "    relay_mod = relay.transform.InferType()(relay_mod)\n",
    "    relay_mod = relay.transform.FoldScaleAxis()(relay_mod)\n",
    "    # 在退出上下文之前，获取 profile 结果。\n",
    "    profiles = timing_inst.render()\n",
    "print(\"Printing results of timing profile...\")\n",
    "print(profiles)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 ('torch': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "20e538bd0bbffa4ce75068aaf85df10d4944f3fdb705eeec6781a4702773116f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
